
library(foreach)
library(doParallel)

pacman::p_load(tidyverse,
               fs,
               fst,
               assertthat,
               tictoc)


# Training ----------------------------------------------------------------
#
# Loads the per-quarter / per-ID .fst files inside the training window, keeps
# the rows whose consumer_category is in CONSUMER_GROUP, derives `is_thick`,
# and collects the distinct quarter/ID keys (one output file is written per
# key further down).
#
# Globals read here (defined outside this section): TRAINING_START_QTR,
# TRAINING_END_QTR, RAND_NO_CID, CONSUMER_GROUP.

training_qtrs <- seq.Date(
  as.Date(TRAINING_START_QTR),
  as.Date(TRAINING_END_QTR),
  by = "quarter"
)

# The quarter (8 consecutive digits) and the 4-digit ID that precedes the
# file extension are parsed out of every file name in the input directory.
# lubridate is only attached by tidyverse from version 2.0.0 onward, so its
# functions are called with an explicit namespace.
file_info <- tibble(
  file_name = dir_ls("../../data_qtr_rand_no_cid_with_outcome//")
) %>%
  mutate(
    qtr = lubridate::ymd(str_extract(file_name, "[0-9]{8}")),
    rand_no_cid = str_extract(file_name, "[0-9]{4}(?=.fst)")
  )

training_files <- file_info %>%
  filter(qtr %in% training_qtrs, rand_no_cid %in% RAND_NO_CID) %>%
  mutate(
    year = lubridate::year(qtr)
  )

# Every requested ID must have at least one file. The comparison is a set
# difference, so it does not depend on the order in which RAND_NO_CID was
# written or on the order the directory listing returns the files.
missing_training_ids <- setdiff(RAND_NO_CID, training_files$rand_no_cid)
if (length(missing_training_ids) > 0) {
  stop(
    "Desired test IDs not present in pulled files: ",
    paste(missing_training_ids, collapse = ", "),
    call. = FALSE
  )
}
stopifnot(
  "No training files matched the requested quarters and IDs" =
    nrow(training_files) > 0
)

tic(msg = "Loading Training")
df_training_raw <- map_dfr(training_files$file_name, function(path) {
  print(path) # progress trace: one line per file read
  # NOTE(review): the two identical "VARNAME" entries look like redacted
  # placeholders for two different columns -- confirm the real names.
  read_fst(path, columns = c("cid", "qtr", "riskscore", "t_default", "consumer_category", "VARNAME", "VARNAME")) %>%
    # Drop unwanted rows first so the columns below are built on fewer rows.
    filter(consumer_category %in% CONSUMER_GROUP) %>%
    mutate(
      rand_no_cid = str_extract(path, "[0-9]{4}(?=.fst)"),
      t_default = factor(t_default, levels = c(0, 1))
    )
})
toc()

df_training_rand_no_cid <- df_training_raw %>%
  rename(
    num_accounts = VARNAME
  ) %>%
  mutate(
    # Key of the form <qtr with dashes removed>_<rand_no_cid>.
    qtr_rand_no_cid = paste0(str_replace_all(qtr, "-", ""), "_", rand_no_cid),
    age_in_years_oldest_account = VARNAME / 12,
    # Missing values become 99999, which is above both thresholds below, so
    # rows with missing inputs end up with is_thick = 1.
    # NOTE(review): confirm that treating missing as "thick" is intended.
    across(
      .cols = c(num_accounts, age_in_years_oldest_account),
      ~ replace_na(., 99999)
    ),
    is_thick = case_when(
      num_accounts < 9 | age_in_years_oldest_account < 10 ~ 0,
      TRUE ~ 1
    )
  ) %>%
  select(
    cid, qtr, riskscore, t_default, consumer_category,
    rand_no_cid, qtr_rand_no_cid, is_thick
  )

v_unique_qtr_rand_no_cid <- unique(df_training_rand_no_cid$qtr_rand_no_cid)

# Write the rows belonging to one quarter/ID key to its own .fst file.
#
# df_fitted_values  Data frame containing at least qtr_rand_no_cid, cid, qtr,
#                   riskscore, t_default and is_thick.
# qtr_rand_no_cids  Key to export; it selects the rows and is embedded in the
#                   output file name.
#
# The output folder is assembled from the globals SPECIAL_SUFFIX,
# TRAINING_SUFFIX, RAND_NO_CID_SMALLEST_LARGEST and MODEL_METRIC.
# Returns whatever write_fst() returns.
Save_Fitted_Values <- function(df_fitted_values, qtr_rand_no_cids) {
  out_dir <- paste0(
    "../../data/pipeline_outputs/", SPECIAL_SUFFIX, "/",
    "fitted_training_riskscore_", TRAINING_SUFFIX, "_",
    RAND_NO_CID_SMALLEST_LARGEST, "_", MODEL_METRIC, "/"
  )
  out_file <- paste0(out_dir, "fitted_", qtr_rand_no_cids, ".fst")

  rows_for_key <- filter(df_fitted_values, qtr_rand_no_cid == qtr_rand_no_cids)
  export_cols <- select(rows_for_key, cid, qtr, riskscore, t_default, is_thick)

  write_fst(export_cols, out_file)
}

# Export one file per distinct quarter/ID key found in the training data.
walk(
  v_unique_qtr_rand_no_cid,
  function(key) Save_Fitted_Values(df_training_rand_no_cid, key)
)

# Test --------------------------------------------------------------------
#
# Same flow as the training section, applied to the test window: list the
# input files, load the ones matching the requested quarters and IDs, derive
# `is_thick`, then write one .fst file per quarter/ID key.
#
# Globals read here (defined outside this section): TEST_START_QTR,
# TEST_END_QTR, RAND_NO_CID, CONSUMER_GROUP, SPECIAL_SUFFIX, SUFFIX_FITTED_.

test_qtrs <- seq.Date(
  as.Date(TEST_START_QTR),
  as.Date(TEST_END_QTR),
  by = "quarter"
)

# The quarter (8 consecutive digits) and the 4-digit ID that precedes the
# file extension are parsed out of every file name in the input directory.
# lubridate is only attached by tidyverse from version 2.0.0 onward, so its
# functions are called with an explicit namespace.
file_info <- tibble(
  file_name = dir_ls("../../data_qtr_rand_no_cid_with_outcome//")
) %>%
  mutate(
    qtr = lubridate::ymd(str_extract(file_name, "[0-9]{8}")),
    rand_no_cid = str_extract(file_name, "[0-9]{4}(?=.fst)")
  )

test_files <- file_info %>%
  filter(qtr %in% test_qtrs, rand_no_cid %in% RAND_NO_CID) %>%
  mutate(
    year = lubridate::year(qtr)
  )

# Every requested ID must have at least one file. The comparison is a set
# difference, so it does not depend on the order in which RAND_NO_CID was
# written or on the order the directory listing returns the files.
missing_test_ids <- setdiff(RAND_NO_CID, test_files$rand_no_cid)
if (length(missing_test_ids) > 0) {
  stop(
    "Desired test IDs not present in pulled files: ",
    paste(missing_test_ids, collapse = ", "),
    call. = FALSE
  )
}
stopifnot(
  "No test files matched the requested quarters and IDs" =
    nrow(test_files) > 0
)

tic(msg = "Loading Test")
df_test_raw <- map_dfr(test_files$file_name, function(path) {
  print(path) # progress trace: one line per file read
  # NOTE(review): the two identical "VARNAME" entries look like redacted
  # placeholders for two different columns -- confirm the real names.
  read_fst(path, columns = c("cid", "qtr", "riskscore", "t_default", "consumer_category", "VARNAME", "VARNAME")) %>%
    # Drop unwanted rows first so the columns below are built on fewer rows.
    filter(consumer_category %in% CONSUMER_GROUP) %>%
    mutate(
      rand_no_cid = str_extract(path, "[0-9]{4}(?=.fst)"),
      t_default = factor(t_default, levels = c(0, 1))
    )
})
toc()

df_test_rand_no_cid <- df_test_raw %>%
  rename(
    num_accounts = VARNAME
  ) %>%
  mutate(
    # Key of the form <qtr with dashes removed>_<rand_no_cid>.
    qtr_rand_no_cid = paste0(str_replace_all(qtr, "-", ""), "_", rand_no_cid),
    age_in_years_oldest_account = VARNAME / 12,
    # Missing values become 99999, which is above both thresholds below, so
    # rows with missing inputs end up with is_thick = 1.
    # NOTE(review): confirm that treating missing as "thick" is intended.
    across(
      .cols = c(num_accounts, age_in_years_oldest_account),
      ~ replace_na(., 99999)
    ),
    is_thick = case_when(
      num_accounts < 9 | age_in_years_oldest_account < 10 ~ 0,
      TRUE ~ 1
    )
  ) %>%
  select(
    cid, qtr, riskscore, t_default, consumer_category,
    rand_no_cid, qtr_rand_no_cid, is_thick
  )

v_unique_qtr_rand_no_cid <- unique(df_test_rand_no_cid$qtr_rand_no_cid)

# Disabled parallel variant of the export loop, kept for reference:
# cl <- makeCluster(7)
# registerDoParallel(cl)
# foreach(qtr_rand_no_cid_ = v_unique_qtr_rand_no_cid, .packages = c("tidyverse", "fst", "tictoc"),
#         .verbose = TRUE) %do% {

# Sequential export: one output file per quarter/ID key. The inner tic/toc
# pair times each file, the outer pair times the whole loop.
tic(msg = "Total Loop Time")
for (qtr_rand_no_cid_ in v_unique_qtr_rand_no_cid) {
  tic()
  df_test_rand_no_cid %>%
    filter(qtr_rand_no_cid %in% qtr_rand_no_cid_) %>%
    select(cid, qtr, riskscore, t_default, is_thick) %>%
    write_fst(
      paste0(
        "../../data/pipeline_outputs/", SPECIAL_SUFFIX, "/", "fitted_riskscore_",
        SUFFIX_FITTED_, "/", "fitted_", qtr_rand_no_cid_, ".fst"
      )
    )
  toc()
}
toc()

# stopCluster(cl)


